import torch
import torch.nn as nn
from models import BasicModule
from sentence_transformers import SentenceTransformer, util


class UniTransformer(BasicModule):
    """Text/image model built around a ``clip-ViT-B-32`` SentenceTransformer.

    The class bundles four pieces:

    * ``text_encoder`` / ``imag_encoder`` -- embed raw inputs with the
      SentenceTransformer backbone.
    * ``uts_encoder`` -- project both modalities into a joint space and fuse
      them with an element-wise product followed by sum pooling.
    * ``align_reason`` -- score text/image alignment and compute an NLL loss.
    * ``htr_decoder`` -- a single step of a recurrent token decoder.

    ``cfg`` must provide ``text_size``, ``img_size``, ``joint_emb_size``,
    ``dropout``, ``vocab_size``, ``embed_size``, ``out_dim`` and
    ``hidden_size``.  ``joint_emb_size`` has to be a multiple of ``out_dim``
    because ``uts_encoder`` reshapes the joint vector into
    ``out_dim`` groups.
    """

    def __init__(self, cfg):
        super().__init__()
        self.transformer_model = SentenceTransformer('clip-ViT-B-32')
        self.linear1 = nn.Linear(cfg.text_size, cfg.joint_emb_size)
        self.linear2 = nn.Linear(cfg.img_size, cfg.joint_emb_size)
        self.dropout = nn.Dropout(cfg.dropout)
        self.embedding = nn.Embedding(cfg.vocab_size, cfg.embed_size)
        self.out_dim = cfg.out_dim
        self.out = nn.Linear(cfg.hidden_size, cfg.vocab_size)
        self.softmax = nn.Softmax(dim=1)
        # Recurrent cell used by htr_decoder.  The SentenceTransformer cannot
        # consume an (input, hidden) pair, so the decoder needs its own cell.
        # A GRU is used because htr_decoder treats ``hidden`` as one tensor
        # (``hidden.size(1)``) rather than an LSTM-style (h, c) tuple.
        # Registered last so existing parameters keep their order.
        self.decoder_rnn = nn.GRU(cfg.embed_size, cfg.hidden_size)

    def text_encoder(self, sequence):
        """Embed text with the SentenceTransformer backbone.

        Returns whatever ``SentenceTransformer.encode`` returns for
        ``sequence`` with default arguments.
        """
        return self.transformer_model.encode(sequence)

    def imag_encoder(self, img):
        """Embed image input with the SentenceTransformer backbone.

        Returns whatever ``SentenceTransformer.encode`` returns for ``img``
        with default arguments.
        """
        return self.transformer_model.encode(img)

    def _to_feature_tensor(self, feat):
        """Turn an encoder output into a 2-D tensor the projections accept."""
        weight = self.linear1.weight
        tensor = torch.as_tensor(feat, dtype=weight.dtype, device=weight.device)
        if tensor.dim() == 1:
            # A single (un-batched) embedding: add the batch axis.
            tensor = tensor.unsqueeze(0)
        return tensor

    def uts_encoder(self, text_feat, img_feat):
        """Fuse text and image features into a ``(batch, out_dim)`` tensor.

        Args:
            text_feat: tensor whose last dimension is ``cfg.text_size``;
                dimension 0 is used as the batch size.
            img_feat: tensor whose last dimension is ``cfg.img_size``.

        Returns:
            Tensor of shape ``(batch, out_dim)``.
        """
        batch_size = text_feat.size(0)
        t_vec = self.linear1(text_feat)  # (batch, joint_emb_size)
        i_vec = self.linear2(img_feat)  # (batch, joint_emb_size)
        # Fuse by element-wise product.  The original passed both vectors to
        # the SentenceTransformer, whose forward does not accept two
        # projected tensors; the reshape + sum pooling below matches
        # factorized bilinear pooling, which multiplies the projections.
        fused = self.dropout(t_vec * i_vec)  # (batch, joint_emb_size)
        # (batch, 1, out_dim, joint_emb_size // out_dim)
        fused = fused.view(batch_size, 1, self.out_dim, -1)
        pooled = fused.sum(dim=3)  # (batch, 1, out_dim)
        # Squeeze only the singleton axis so a batch of one keeps its
        # batch dimension.
        return pooled.squeeze(1)  # (batch, out_dim)

    def align_reason(self, text_feat, img_feat, true):
        """Score text/image alignment and compute the NLL loss.

        Args:
            text_feat: text feature tensor.
            img_feat: image feature tensor.
            true: class-index targets for the alignment scores, in the
                format expected by ``nn.functional.nll_loss``.

        Returns:
            ``(nll_loss, table)`` where ``table`` holds the shapes of the two
            inputs, one row per modality.
        """
        # NOTE(review): the original passed the bound ``.size`` methods to
        # ``torch.tensor``, which raises.  This records both shapes instead
        # (requires equal rank) -- confirm that is the intended "table".
        table = torch.tensor([list(text_feat.size()), list(img_feat.size())])
        # NOTE(review): ``align`` is not defined in this class; presumably it
        # is provided by BasicModule -- confirm.
        result = self.align(text_feat, img_feat)
        # nn.Softmax takes a single input, so the target cannot be passed to
        # it.  Use log-probabilities over dim 1 (the same axis self.softmax
        # uses) and the NLL criterion.
        log_probs = torch.log_softmax(result, dim=1)
        nll_loss = nn.functional.nll_loss(log_probs, true)
        return nll_loss, table

    def htr_decoder(self, word, hidden):
        """Run one decoding step.

        Args:
            word: token-id tensor; ``word[0]`` holds the ids for the current
                step, one per batch element.
            hidden: recurrent state of shape ``(1, batch, hidden_size)``.

        Returns:
            ``(output, hidden)``: softmax distribution over the vocabulary
            with shape ``(batch, vocab_size)`` and the updated state.
        """
        batch_size = hidden.size(1)
        embed = self.embedding(word[0]).view(1, batch_size, -1)
        output, hidden = self.decoder_rnn(embed, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def forward(self, text, img):
        """Encode both modalities and return their fused representation.

        Decoding is not run here: ``htr_decoder`` needs a start token and an
        initial hidden state that this method does not receive, so callers
        drive it step by step.
        """
        input_text = self._to_feature_tensor(self.text_encoder(text))
        input_img = self._to_feature_tensor(self.imag_encoder(img))
        return self.uts_encoder(input_text, input_img)


